{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/ec2-user/miniconda3/bin/python\n"
     ]
    }
   ],
   "source": [
    "!which python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"/home/ec2-user/nta/nupic.research\")\n",
    "sys.path.append(\"/home/ec2-user/nta/nupic.torch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import torch\n",
    "\n",
    "from models import MultiHeadedSparseMLP, MultiHeadedDendriticMLP, DendriticMLP, SparseMLP\n",
    "\n",
    "from nupic.research.frameworks.pytorch.model_utils import count_nonzero_params\n",
    "\n",
    "from nupic.torch.modules import rezero_weights\n",
    "\n",
    "from nupic.research.frameworks.pytorch.models.common_models import StandardMLP\n",
    "\n",
    "from nupic.research.frameworks.dendrites.modules.dendritic_layers import BiasingDendriticLayer\n",
    "\n",
    "from copy import deepcopy\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data_with_context = torch.rand((8,21))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sparse MLP baseline: called with a single 21-feature input, i.e. the 11 data\n",
    "# features and the 10 context features concatenated.\n",
    "sparse_net = SparseMLP(\n",
    "    input_size=21,\n",
    "    output_dim=7,\n",
    "    hidden_sizes=(2048,) * 3,\n",
    "    linear_weight_percent_on=(0.5,) * 3,\n",
    "    linear_activity_percent_on=(0.1,) * 3,\n",
    "    use_batch_norm=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3 µs, sys: 0 ns, total: 3 µs\n",
      "Wall time: 7.15 µs\n",
      "tensor([[-7.6795e-03,  3.7848e-03,  6.9775e-03, -9.3998e-03, -3.0156e-03,\n",
      "          3.1120e-03, -8.1030e-03],\n",
      "        [-7.4035e-03,  1.1602e-02,  7.0175e-03, -6.0808e-03, -5.8343e-03,\n",
      "         -2.1367e-03, -1.7176e-02],\n",
      "        [-3.5647e-03,  4.4708e-03,  1.0158e-02, -2.0132e-03, -1.0010e-02,\n",
      "         -3.2798e-03, -1.5986e-02],\n",
      "        [-7.7780e-03,  5.2119e-03,  4.1889e-03, -1.0919e-02, -6.2680e-03,\n",
      "          3.6875e-03, -1.1003e-02],\n",
      "        [-1.4235e-02,  1.8707e-02,  1.0515e-02, -6.6521e-03, -1.3724e-02,\n",
      "          6.5169e-06, -1.1407e-02],\n",
      "        [ 4.1190e-03,  1.6794e-02,  6.1752e-03, -1.4180e-03, -7.3284e-03,\n",
      "          4.6146e-04, -1.1511e-02],\n",
      "        [-1.0167e-02,  1.5454e-02,  5.4558e-03,  6.2702e-04, -7.3032e-03,\n",
      "         -5.3850e-03, -1.5711e-02],\n",
      "        [-1.2256e-02,  1.8944e-02,  8.3157e-03, -2.9518e-03, -2.3077e-03,\n",
      "         -2.5419e-04, -2.1103e-02]], grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Cell magic: times the whole cell. A bare `%time` on its own line only times an\n",
    "# empty statement, which is why it reported a few microseconds.\n",
    "output = sparse_net(test_data_with_context)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dendritic MLP with dendritic_layer_class left at its default. The 11 data\n",
    "# features and the 10 context features are passed as two separate inputs.\n",
    "dendrite_net = DendriticMLP(\n",
    "    input_size=11,\n",
    "    dim_context=10,\n",
    "    output_dim=7,\n",
    "    hidden_sizes=(2048,) * 3,\n",
    "    num_segments=(30,) * 3,\n",
    "    sparsity=0.5,\n",
    "    kw=False,\n",
    "    relu=True,\n",
    "    kw_percent_on=1.0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data_without_context = torch.rand((8,11))\n",
    "test_context = torch.rand((8,10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1 µs, sys: 0 ns, total: 1 µs\n",
      "Wall time: 6.2 µs\n",
      "tensor([[ 0.0021,  0.0039,  0.0177, -0.0183, -0.0049, -0.0017,  0.0040],\n",
      "        [ 0.0029,  0.0058,  0.0173, -0.0170, -0.0067, -0.0014,  0.0042],\n",
      "        [ 0.0038,  0.0043,  0.0166, -0.0185, -0.0061, -0.0023,  0.0030],\n",
      "        [ 0.0045,  0.0052,  0.0177, -0.0180, -0.0056, -0.0002,  0.0041],\n",
      "        [ 0.0048,  0.0048,  0.0191, -0.0194, -0.0071, -0.0029,  0.0046],\n",
      "        [ 0.0026,  0.0054,  0.0159, -0.0190, -0.0079, -0.0015,  0.0039],\n",
      "        [ 0.0029,  0.0042,  0.0184, -0.0175, -0.0064, -0.0010,  0.0042],\n",
      "        [ 0.0031,  0.0040,  0.0177, -0.0184, -0.0066, -0.0005,  0.0051]],\n",
      "       grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Cell magic: times the whole cell. A bare `%time` on its own line only times an\n",
    "# empty statement, which is why it reported a few microseconds.\n",
    "output = dendrite_net(test_data_without_context, test_context)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second dendritic MLP: same arguments as dendrite_net, except that the dendritic\n",
    "# layers are explicitly set to BiasingDendriticLayer.\n",
    "dendrite_net2 = DendriticMLP(\n",
    "    input_size=11,\n",
    "    dim_context=10,\n",
    "    output_dim=7,\n",
    "    hidden_sizes=(2048,) * 3,\n",
    "    num_segments=(30,) * 3,\n",
    "    sparsity=0.5,\n",
    "    kw=False,\n",
    "    relu=True,\n",
    "    kw_percent_on=1.0,\n",
    "    dendritic_layer_class=BiasingDendriticLayer,\n",
    ")"
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3 µs, sys: 0 ns, total: 3 µs\n",
      "Wall time: 6.44 µs\n",
      "tensor([[ 0.2199,  0.0905,  0.1041,  0.1568,  0.1492, -0.5456, -0.2105],\n",
      "        [ 0.1484,  0.0395,  0.0559,  0.0640,  0.1299, -0.4030, -0.1294],\n",
      "        [ 0.1744,  0.0723,  0.1382,  0.0609,  0.0445, -0.3637, -0.2328],\n",
      "        [ 0.2002,  0.1502,  0.0076,  0.1395,  0.0607, -0.5386, -0.2957],\n",
      "        [ 0.1378,  0.0193,  0.1687,  0.1601,  0.1357, -0.3876, -0.1388],\n",
      "        [ 0.2569,  0.0463,  0.0882,  0.1333,  0.1584, -0.4526, -0.2106],\n",
      "        [ 0.1801,  0.1057,  0.1073,  0.0925,  0.1084, -0.4615, -0.2016],\n",
      "        [ 0.1095,  0.0235,  0.0601,  0.0100,  0.0649, -0.3919, -0.1522]],\n",
      "       grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Cell magic: times the whole cell. A bare `%time` on its own line only times an\n",
    "# empty statement, which is why it reported a few microseconds.\n",
    "output = dendrite_net2(test_data_without_context, test_context)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense baseline taking the same 21-feature concatenated input as the sparse MLP.\n",
    "dense_net = StandardMLP(\n",
    "    hidden_sizes=(2048,) * 3,\n",
    "    input_size=21,\n",
    "    num_classes=7,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3 µs, sys: 0 ns, total: 3 µs\n",
      "Wall time: 6.91 µs\n",
      "tensor([[ 0.0233,  0.0151, -0.0164,  0.0210, -0.0011,  0.0606, -0.0041],\n",
      "        [ 0.0246,  0.0303, -0.0120,  0.0199,  0.0279,  0.0658,  0.0231],\n",
      "        [ 0.0072,  0.0339, -0.0156,  0.0378,  0.0111,  0.0486,  0.0422],\n",
      "        [ 0.0227,  0.0251, -0.0251,  0.0223,  0.0169,  0.0477,  0.0414],\n",
      "        [ 0.0225,  0.0143, -0.0296,  0.0042,  0.0103,  0.0580,  0.0302],\n",
      "        [ 0.0086,  0.0306, -0.0218,  0.0060,  0.0151,  0.0698,  0.0412],\n",
      "        [-0.0063,  0.0374, -0.0150,  0.0115, -0.0043,  0.0467,  0.0028],\n",
      "        [ 0.0122,  0.0219, -0.0309,  0.0065, -0.0067,  0.0507,  0.0020]],\n",
      "       grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Cell magic: times the whole cell. A bare `%time` on its own line only times an\n",
    "# empty statement, which is why it reported a few microseconds.\n",
    "output = dense_net(test_data_with_context)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next: time a small forward (and, optionally, backward) pass through each model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the sparse MLP and swap the modules registered as linear1_kwinners,\n",
    "# linear2_kwinners and linear3_kwinners for plain ReLUs; the original is untouched.\n",
    "sparse_net2 = deepcopy(sparse_net)\n",
    "for layer_idx in (1, 2, 3):\n",
    "    setattr(sparse_net2._hidden_base, f\"linear{layer_idx}_kwinners\", nn.ReLU())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3 µs, sys: 0 ns, total: 3 µs\n",
      "Wall time: 6.91 µs\n",
      "tensor([[-0.0147,  0.0132, -0.0067, -0.0027, -0.0110,  0.0012, -0.0065],\n",
      "        [-0.0120,  0.0159, -0.0032,  0.0031, -0.0192,  0.0010, -0.0133],\n",
      "        [-0.0158,  0.0137, -0.0030,  0.0055, -0.0234,  0.0052, -0.0161],\n",
      "        [-0.0184,  0.0147, -0.0010,  0.0041, -0.0177,  0.0026, -0.0091],\n",
      "        [-0.0155,  0.0175, -0.0022, -0.0006, -0.0155,  0.0019, -0.0113],\n",
      "        [-0.0149,  0.0190, -0.0019,  0.0033, -0.0158,  0.0053, -0.0092],\n",
      "        [-0.0163,  0.0217, -0.0012,  0.0052, -0.0145,  0.0042, -0.0154],\n",
      "        [-0.0188,  0.0232, -0.0073,  0.0059, -0.0090,  0.0011, -0.0090]],\n",
      "       grad_fn=<AddmmBackward>)\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Cell magic: times the whole cell. A bare `%time` on its own line only times an\n",
    "# empty statement, which is why it reported a few microseconds.\n",
    "output = sparse_net2(test_data_with_context)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when one is available; otherwise fall back to the CPU so that the\n",
    "# notebook still runs end to end instead of failing in model.to().\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "for model in [dense_net, sparse_net, dendrite_net, sparse_net2, dendrite_net2]:\n",
    "    model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !cat ~/nta/nupic.research/nupic/research/frameworks/dendrites/functional/apply_dendrites.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, epochs=1, batch_size=64, dendrites=False, backward_pass=False,\n",
    "          input_dim=11, context_dim=10, output_dim=7):\n",
    "    \"\"\"\n",
    "    Run random batches through ``model`` so that the call can be timed.\n",
    "\n",
    "    Every iteration draws a fresh random batch directly on ``device``. Inputs and\n",
    "    targets are random, so the loop only exercises the forward (and, optionally,\n",
    "    backward) pass.\n",
    "\n",
    "    :param model: network to run; called as ``model(data, context)`` when\n",
    "                  ``dendrites`` is True, otherwise as ``model(data)`` with a single\n",
    "                  input of width ``input_dim + context_dim``\n",
    "    :param device: torch device (or device string) on which batches are created\n",
    "    :param epochs: number of batches to run\n",
    "    :param batch_size: number of samples per batch\n",
    "    :param dendrites: whether the model takes the context as a second argument\n",
    "    :param backward_pass: when True, also compute an MSE loss against a random\n",
    "                          target, backpropagate and take one SGD step (lr=0.01)\n",
    "    :param input_dim: width of the data input\n",
    "    :param context_dim: width of the context input\n",
    "    :param output_dim: width of the random target; must match the model output\n",
    "                       when ``backward_pass`` is True\n",
    "    \"\"\"\n",
    "    device = torch.device(device)\n",
    "\n",
    "    optim = torch.optim.SGD(lr=0.01, params=model.parameters())\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    for _ in range(epochs):\n",
    "        target = torch.randn(batch_size, output_dim, device=device)\n",
    "\n",
    "        if dendrites:\n",
    "            data = torch.rand(batch_size, input_dim, device=device)\n",
    "            context = torch.rand(batch_size, context_dim, device=device)\n",
    "            output = model(data, context)\n",
    "        else:\n",
    "            data = torch.rand(batch_size, input_dim + context_dim, device=device)\n",
    "            output = model(data)\n",
    "\n",
    "        if backward_pass:\n",
    "            loss = loss_fn(output, target)\n",
    "            loss.backward()\n",
    "\n",
    "            optim.step()\n",
    "\n",
    "            # Re-zeroing of sparse weights is currently disabled:\n",
    "            # model.apply(rezero_weights)\n",
    "            optim.zero_grad()\n",
    "\n",
    "    # CUDA kernels are launched asynchronously. Wait for the queued work to finish so\n",
    "    # that a surrounding %time measures the computation, not just the kernel launches.\n",
    "    if device.type == \"cuda\":\n",
    "        torch.cuda.synchronize(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.33 s, sys: 440 ms, total: 1.77 s\n",
      "Wall time: 1.77 s\n"
     ]
    }
   ],
   "source": [
    "%time train(dense_net, device, epochs=epochs, batch_size=1024)   # dense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.14 s, sys: 280 ms, total: 2.42 s\n",
      "Wall time: 2.41 s\n"
     ]
    }
   ],
   "source": [
    "%time train(sparse_net, device, epochs=epochs, batch_size=1024)  # kwinners - 50% slower"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.07 s, sys: 420 ms, total: 1.49 s\n",
      "Wall time: 1.49 s\n"
     ]
    }
   ],
   "source": [
    "%time train(sparse_net2, device, epochs=epochs, batch_size=1024)  # no kwinners"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 15.8 s, sys: 8.78 s, total: 24.6 s\n",
      "Wall time: 24.6 s\n"
     ]
    }
   ],
   "source": [
    "%time train(dendrite_net, device, epochs=epochs, batch_size=1024, dendrites=True)  # gating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%time train(dendrite_net2, device, epochs=epochs, batch_size=1024, dendrites=True)  # bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_nonzero_params(dendrite_net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_nonzero_params(sparse_net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sparse_net2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dense_net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
